{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "import re\n",
    "\n",
    "def cleaning(string):\n",
    "    string = string.replace('\\n', ' ').replace('\\t', ' ')\n",
    "    string = re.sub(r'[ ]+', ' ', string).strip()\n",
    "    return string"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "import tensorflow_datasets as tfds\n",
    "from t5.data import preprocessors as prep\n",
    "import functools\n",
    "import t5\n",
    "import gin\n",
    "import sentencepiece as spm\n",
    "from glob import glob\n",
    "import os\n",
    "import json\n",
    "\n",
    "gin.parse_config_file('pretrained_models_base_operative_config.gin')\n",
    "vocab = 'sp10m.cased.ms-en.model'\n",
    "sp = spm.SentencePieceProcessor()\n",
    "sp.Load(vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "directory = '/home/husein/pure-text/'\n",
    "paws = ['paws-train.json.translated', 'paws-dev.json.translated',\n",
    "        'paws-test.json.translated']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "43658 43658\n",
      "7078 7078\n",
      "7072 7072\n"
     ]
    }
   ],
   "source": [
    "X, Y = [], []\n",
    "for file in paws:\n",
    "    x, y = [], []\n",
    "    with open(os.path.join(directory, file)) as fopen:\n",
    "        for l in fopen:\n",
    "            l = json.loads(l)\n",
    "            x.append(l['sentence1_ms'])\n",
    "            y.append(l['sentence2_ms'])\n",
    "            x.append(l['sentence2_ms'])\n",
    "            y.append(l['sentence1_ms'])\n",
    "            \n",
    "    print(len(x), len(y))\n",
    "    X.extend(x)\n",
    "    Y.extend(y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "parasci_src = ['train.src.truecase.translated', \n",
    "               'train.src.1.truecase.translated',\n",
    "               'val.src.truecase.translated',\n",
    "               'val.src.1.truecase.translated']\n",
    "parasci_tgt = ['train.tgt.truecase.translated', \n",
    "               'train.tgt.1.truecase.translated',\n",
    "               'val.src.truecase.translated',\n",
    "               'val.tgt.1.truecase.translated']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "57766 57766\n",
      "275234 275234\n",
      "5506 5506\n",
      "7360 7360\n"
     ]
    }
   ],
   "source": [
    "for i in range(len(parasci_src)):\n",
    "    file_src = os.path.join(directory, parasci_src[i])\n",
    "    file_tgt = os.path.join(directory, parasci_tgt[i])\n",
    "    with open(file_src) as fopen:\n",
    "        left = fopen.read().split('\\n')\n",
    "    with open(file_tgt) as fopen:\n",
    "        right = fopen.read().split('\\n')\n",
    "    if len(left) < len(right):\n",
    "        right = right[:len(left)]\n",
    "    else:\n",
    "        left = left[:len(right)]\n",
    "        \n",
    "    x, y = [], []\n",
    "    for k in range(len(left)):\n",
    "        if len(left[k]) and len(right[k]):\n",
    "            l_left = json.loads(left[k])\n",
    "            l_right = json.loads(right[k])\n",
    "            x.append(l_left['ms'])\n",
    "            y.append(l_right['ms'])\n",
    "            x.append(l_right['ms'])\n",
    "            y.append(l_left['ms'])\n",
    "            \n",
    "    print(len(x), len(y))\n",
    "    X.extend(x)\n",
    "    Y.extend(y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(403674, 403674)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(X), len(Y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 403674/403674 [00:08<00:00, 49100.94it/s]\n"
     ]
    }
   ],
   "source": [
    "with tf.io.gfile.GFile('paraphrase.tsv', \"w\") as outfile:\n",
    "    for i in tqdm(range(len(X))):\n",
    "        if len(X) and len(Y):\n",
    "            l = cleaning(X[i])\n",
    "            r = cleaning(Y[i])\n",
    "            outfile.write(\"%s\\t%s\\n\" % (l, r))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def paraphrase_dataset(split, shuffle_files = False):\n",
    "    del shuffle_files\n",
    "    ds = tf.data.TextLineDataset(\n",
    "        [\n",
    "            'paraphrase.tsv'\n",
    "        ]\n",
    "    )\n",
    "\n",
    "    ds = ds.map(\n",
    "        functools.partial(\n",
    "            tf.io.decode_csv,\n",
    "            record_defaults = ['', ''],\n",
    "            field_delim = '\\t',\n",
    "            use_quote_delim = False,\n",
    "        ),\n",
    "        num_parallel_calls = tf.data.experimental.AUTOTUNE,\n",
    "    )\n",
    "    ds = ds.map(lambda *ex: dict(zip(['question', 'answer'], ex)))\n",
    "    return ds\n",
    "\n",
    "def paraphrase_preprocessor(ds):\n",
    "    def to_inputs_and_targets(ex):\n",
    "        return {\n",
    "            'inputs': tf.strings.join(['parafrasa: ', ex['question']]),\n",
    "            'targets': ex['answer'],\n",
    "        }\n",
    "\n",
    "    return ds.map(\n",
    "        to_inputs_and_targets,\n",
    "        num_parallel_calls = tf.data.experimental.AUTOTUNE,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "t5.data.TaskRegistry.remove('paraphrase_dataset')\n",
    "t5.data.TaskRegistry.add(\n",
    "    'paraphrase_dataset',\n",
    "    dataset_fn = paraphrase_dataset,\n",
    "    splits = ['train'],\n",
    "    text_preprocessor = [paraphrase_preprocessor],\n",
    "    sentencepiece_model_path = vocab,\n",
    "    metric_fns = [t5.evaluation.metrics.accuracy],\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "nq_task = t5.data.TaskRegistry.get(\"paraphrase_dataset\")\n",
    "ds = nq_task.get_dataset(split='paraphrase.tsv', sequence_length={\"inputs\": 1024, \"targets\": 1024})\n",
    "r = tfds.as_numpy(ds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'inputs_plaintext': b'parafrasa: Pada 30 Disember 1888, beliau berkahwin dengan Susie J. Clarke di Ayer, Massachusetts.',\n",
       " 'inputs': array([  445,  4435,   722,    31,   206,   480,  1386,   375,  5782,\n",
       "           14,   596,  3887,    28, 24529,    81,   638,     3, 18094,\n",
       "           24, 13999,    14,  2470,     3,     1]),\n",
       " 'targets_plaintext': b'Pada 30 Disember 1888, Susie berkahwin dengan Clarke Brown di Ayer, Massachusetts.',\n",
       " 'targets': array([  206,   480,  1386,   375,  5782,    14, 24529,    81,  3887,\n",
       "           28, 18094,  1659,    24, 13999,    14,  2470,     3,     1])}"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "next(r)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
